import os
from typing import List, Optional

import torch
import torch._inductor.config as inductor_config
from torch._functorch.aot_autograd import aot_module_simplified
from torch._inductor.compile_fx import compile_fx, select_decomp_table
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._subclasses import FakeTensor
from torch.fx import GraphModule

import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm.mapping import Mapping

from .multi_stream.auto_multi_stream import multi_stream_schedule
from .patterns.ar_residual_norm import register_ar_fusions
from .patterns.residual_add_norm import register_add_norm
from .piecewise_optimizer import piecewise_optimizer
from .recover_pass import recover_pass
from .remove_copy_pass import remove_copy_for_mutates_args


class Backend:
    """torch.compile backend that applies TensorRT-LLM graph passes.

    An instance is called with ``(gm, example_inputs)``. Each call runs
    ``recover_pass`` on the graph and then hands it to
    ``aot_module_simplified`` with :meth:`optimize` as the forward compiler.
    ``optimize`` applies the custom pattern-matcher passes and then, depending
    on configuration, multi-stream scheduling, the piecewise CUDA graph
    optimizer, or Inductor (``compile_fx``).
    """

    # Pattern-matcher passes shared by every Backend instance; built lazily
    # by the first call to get_custom_pass().
    _custom_pass_instances: List[PatternMatcherPass] = None
    # CUDA graph memory-pool handle shared by every Backend instance; created
    # by the first __init__.
    _graph_pool_handle: tuple[int, int] = None

    # Following classes are used to let weakref ref the stream and eventlist objects.
    # (Plain ``list`` instances cannot be weakly referenced; subclasses can.)
    class Streams(list):
        pass

    class Events(list):
        pass

    def __init__(
        self,
        enable_inductor=True,
        enable_userbuffers=False,
        enable_piecewise_cuda_graph: bool = False,
        capture_num_tokens: Optional[List[int]] = None,
        max_num_streams: int = 1,
        mapping=None,
    ) -> None:
        """Configure the backend.

        Args:
            enable_inductor: Compile the optimized graph with ``compile_fx``;
                also forwarded to ``piecewise_optimizer`` when piecewise CUDA
                graph is enabled.
            enable_userbuffers: Request userbuffers for the all-reduce
                fusions. Only used when the MPI world size is > 1 and
                userbuffers are supported, and only if this is the first
                Backend to build the shared custom passes.
            enable_piecewise_cuda_graph: Route the optimized graph through
                ``piecewise_optimizer``.
            capture_num_tokens: Token counts forwarded (sorted ascending) to
                ``piecewise_optimizer``; ``None`` is treated as an empty list.
            max_num_streams: Stream budget; values > 1 enable multi-stream
                scheduling.
            mapping: Forwarded to ``register_ar_fusions`` when the shared
                custom passes are first built.
        """
        super().__init__()
        self.elapsed_time = 0
        self.module_inference_event = []
        self.module_inference_time = 0
        self.call_count = 0
        self.mapping = mapping
        # The passes are cached on the class: only the first Backend's
        # enable_userbuffers/mapping influence which fusions are registered.
        self.custom_passes = Backend.get_custom_pass(enable_userbuffers,
                                                     mapping)
        self.rank = tensorrt_llm.mpi_rank()
        self.enable_inductor = enable_inductor
        self.capture_num_tokens = sorted(capture_num_tokens or [])
        self.piecewise_cuda_graph = enable_piecewise_cuda_graph
        self.no_optimization = False
        self.num_streams = max_num_streams
        self.events = Backend.Events()
        # Filled in by __call__ from the ``l_input_ids_`` placeholder. It must
        # exist up front so that the piecewise-CUDA-graph check in __call__
        # reports its assertion message instead of raising AttributeError.
        self.input_num_tokens = None
        # Note: this mutates process-global Inductor configuration.
        inductor_config.enable_auto_functionalized_v2 = False

        if Backend._graph_pool_handle is None:
            Backend._graph_pool_handle = torch.cuda.graph_pool_handle()

        # Match counts returned by every custom-pass application; accumulates
        # across optimize() calls.
        self.match_count = []

    @classmethod
    def get_custom_pass(cls, enable_userbuffers, mapping: Mapping):
        """Return the class-wide list of pattern-matcher passes.

        The list is built on the first call only; later calls return the
        cached list and ignore their arguments. With MPI world size > 1 the
        all-reduce fusions are registered (and the Lamport reduce-norm fusion
        is disabled through an environment variable); otherwise the
        residual-add + norm fusion is registered.
        """
        world_size = tensorrt_llm.mpi_world_size()
        if not cls._custom_pass_instances:
            # Really naive pass manager here
            cls._custom_pass_instances = [PatternMatcherPass()]
            if world_size > 1:
                # Currently torch compile cannot work properly with lamport fusion kernel
                # TO-DO: Fix this issue
                os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
                # Short-circuit: only query userbuffers support when requested.
                ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
                )
                register_ar_fusions(cls._custom_pass_instances, mapping,
                                    ub_enabled)
            else:
                register_add_norm(cls._custom_pass_instances[0])
        return cls._custom_pass_instances

    def bypass_optimization(self):
        """Make subsequent __call__ invocations return the graph unchanged."""
        self.no_optimization = True

    def enable_optimization(self):
        """Re-enable optimization after bypass_optimization()."""
        self.no_optimization = False

    def generate_events(self, num_events: int):
        """Grow ``self.events`` in place to at least ``num_events`` CUDA events.

        The list is extended rather than replaced so that the same ``Events``
        object (and any weak references to it) stays valid.
        """
        if num_events > len(self.events):
            self.events += [
                torch.cuda.Event() for _ in range(num_events - len(self.events))
            ]

    def optimize(
        self,
        gm: GraphModule,
        example_inputs: List[torch.Tensor],
    ):
        """Forward compiler passed to ``aot_module_simplified``.

        Applies each custom pass repeatedly until it stops matching, removes
        dead code and redundant copies for mutated arguments, optionally
        schedules multiple streams, and then returns either the
        piecewise-optimized module, the ``compile_fx`` result, or ``gm``.
        """
        graph = gm.graph
        for custom_pass in self.custom_passes:
            # Re-apply a pass until it reports no more matches (fixed point).
            self.match_count.append(custom_pass.apply(graph))
            while self.match_count[-1]:
                self.match_count.append(custom_pass.apply(graph))
        graph.eliminate_dead_code()
        # After this pass, cannot run any dce!!!
        remove_copy_for_mutates_args(graph)

        # Do not apply multi-stream if enable piecewise cuda graph or inductor
        # For piecewise cuda graph, we will apply the multi-stream optimization in piecewise_optimizer
        # For inductor, we do not control the passes inside inductor.
        if self.num_streams > 1 and not self.piecewise_cuda_graph and not self.enable_inductor:
            num_events = multi_stream_schedule(gm, self.num_streams)
            self.generate_events(num_events)

        gm.recompile()

        if self.piecewise_cuda_graph:
            gm, num_events = piecewise_optimizer(
                gm,
                example_inputs,
                self.enable_inductor,
                self.input_num_tokens,
                self.capture_num_tokens,
                self._graph_pool_handle,
                self.num_streams,
            )
            self.generate_events(num_events)
            return gm
        elif self.enable_inductor:
            return compile_fx(gm, example_inputs)
        else:
            return gm

    def __call__(self, gm: GraphModule,
                 example_inputs: List[torch.Tensor]) -> callable:
        """Backend entry point called with a captured FX graph.

        When optimization is bypassed, logs a warning and returns ``gm``
        unchanged. Otherwise it records the first dimension of the
        ``l_input_ids_`` placeholder's example value as ``input_num_tokens``,
        runs ``recover_pass``, and compiles through ``aot_module_simplified``
        with :meth:`optimize` as the forward compiler.

        Raises:
            AssertionError: piecewise CUDA graph is enabled but no
                ``input_num_tokens`` is known.
        """

        if self.no_optimization:
            logger.warning(
                "Bypassing torch.compile optimization and fallback to eager execution!"
            )
            return gm

        # NOTE(review): input_num_tokens is only overwritten when the
        # placeholder is present, so a graph without it reuses the value
        # recorded by an earlier call (or None). Confirm this is intended.
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                if node.name == "l_input_ids_":
                    example_value = node.meta["example_value"]
                    assert isinstance(example_value, FakeTensor)
                    self.input_num_tokens = example_value.shape[0]
                    break

        if self.piecewise_cuda_graph:
            assert (
                self.input_num_tokens is not None
            ), "Cannot detect input_num_tokens. Cannot use piecewise CUDA graph. What is the name of `input_ids`?"

        gm = recover_pass(gm)

        return aot_module_simplified(
            gm,
            example_inputs,
            fw_compiler=self.optimize,
            decompositions=select_decomp_table(),
        )
